# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import copy
import json
import math
import os
import sys
import time
from collections import OrderedDict
from contextlib import nullcontext
from datetime import datetime, timedelta
from functools import partial
import types
from typing import List, Optional, Tuple, Union, Dict
from torch import Tensor

from torch.distributed._composable.fsdp import fully_shard
from torch.distributed._composable import checkpoint
from torch.distributed._composable.fsdp import MixedPrecisionPolicy

import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import ReduceOp
import torch.nn.functional as F
import torch.distributed.checkpoint as dcp
import torch.nn as nn

from torch.nn import CrossEntropyLoss

from liger_kernel.ops.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyFunction,
)
from liger_kernel.transformers.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyLoss,
)
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss

from transformers.modeling_outputs import (BaseModelOutputWithPast,
                                           CausalLMOutputWithPast,
                                           MoeCausalLMOutputWithPast,
                                           SequenceClassifierOutputWithPast)

from accelerate.utils import set_module_tensor_to_device
from datasets import load_from_disk
from mmengine import mkdir_or_exist
from mmengine.dist import infer_launcher, init_dist
from mmengine.runner import set_random_seed
from mmengine.utils import get_git_hash
from mmengine.utils.dl_utils import collect_env
from mmengine import MessageHub
from tabulate import tabulate
from torch.distributed._tensor import Shard, distribute_tensor, Replicate, DTensor
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import \
    apply_activation_checkpointing
from torch.distributed.checkpoint.state_dict import (StateDictOptions,
                                                     get_state_dict,
                                                     set_state_dict)
from torch.distributed.device_mesh import init_device_mesh, DeviceMesh
from torch.distributed.fsdp import CPUOffload
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
from torch.distributed.fsdp.wrap import _or_policy
from torch.nn import CrossEntropyLoss
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR
from torch.profiler import ProfilerActivity, profile, record_function
from torch.utils.data import ConcatDataset, DataLoader
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.modeling_utils import PreTrainedModel, load_state_dict
from transformers.utils import (SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME,
                                is_safetensors_available)
from transformers.utils.import_utils import (is_flash_attn_2_available,
                                             is_torch_sdpa_available)

from xtuner._lite import (AutoTokenizer, get_device, get_logger,
                          get_torch_device_module)
from xtuner._lite.accelerate import (LORA_TARGET_MAP, dispatch_hf_code, LoadWoInit,
                                     packed_sequence, varlen_attn_is_available, profile_time_and_memory)
from xtuner._lite.algorithms.sft import SftCollator, SftTokenizeFunction
from xtuner._lite.chat import CHAT_TEMPLATE_MAP
from xtuner._lite.datasets import (DATASET_CLS_MAP, OPENAI_CONVERT_MAP,
                                   SoftPackDataset, HardPackDataset, load_datasets)
from xtuner._lite.parallel import (LengthGroupedSampler, ParallelSampler,
                                   get_dp_mesh, get_sp_mesh,
                                   pad_for_sequence_parallel,
                                   reduce_sequence_parallel_loss,
                                   setup_parallel, split_for_sequence_parallel)

from xtuner._lite.parallel import (ParallelSampler, get_dp_mesh, get_fsdp_mesh,
                                   get_sp_mesh, get_tp_mesh, get_world_mesh, get_same_data_mesh,
                                   pad_for_sequence_parallel, setup_parallel,
                                   reduce_sequence_parallel_loss,
                                   split_for_sequence_parallel,
                                   get_ep_mesh, get_experts_fsdp_mesh)

from xpuyu.datasets import (MultiStreamingDataset, PretrainTokenizeFunction,
                            Streaming, StreamingDataset)
from xpuyu.accelerate import dispatch_hf_code
from xpuyu.parallel.megatron import megatron_internlm3_moe_casual


# Module-wide logger obtained from xtuner's `get_logger` helper.
logger = get_logger()

# Device handle and the matching torch device module as returned by xtuner's
# helpers (presumably e.g. 'cuda' / `torch.cuda` -- confirm against
# `xtuner._lite`).
DEVICE = get_device()
DEVICE_MODULE = get_torch_device_module()

# Keys of `OPENAI_CONVERT_MAP`. Note this is a live `dict_keys` view, not a
# copied list.
SUPPORT_DATA_FORMATS = OPENAI_CONVERT_MAP.keys()

def log_format(rank, debug=False):
    """Build the format string passed to ``logger.add(..., format=...)``.

    The prefix embeds ``rank`` together with this process' local rank inside
    the data-parallel, sequence-parallel and tensor-parallel meshes, followed
    by the timestamp and level placeholders.

    Args:
        rank: Rank number printed in the ``[RANK ...]`` tag.
        debug: When true, the ``{name}:{function}:{line}`` location of the
            log call is added before the message.

    Returns:
        str: The assembled format string.
    """
    sp_rank = get_sp_mesh().get_local_rank()
    dp_rank = get_dp_mesh().get_local_rank()
    tp_rank = get_tp_mesh().get_local_rank()
    # The FSDP local rank is queried but not part of the prefix; the call is
    # kept so the function still touches the FSDP mesh exactly as before.
    get_fsdp_mesh().get_local_rank()

    pieces = [
        f'[XTuner][RANK {rank}][DP {dp_rank}][SP {sp_rank}][TP {tp_rank}]',
        '[{time:YYYY-MM-DD HH:mm:ss}][<level>{level}</level>]',
    ]

    if debug:
        pieces.extend((
            '[<cyan>{name}</cyan>:',
            '<cyan>{function}</cyan>:',
            '<cyan>{line}</cyan>]',
        ))

    pieces.append(' <level>{message}</level>')
    return ''.join(pieces)


def parse_args():
    """Parse the command-line options of the training script.

    Options are read from ``sys.argv``. Option names (including the
    historical ``mirco`` spelling), types and defaults are part of the
    script's interface and are kept as they were; only help texts that
    described a different option, and placeholder group descriptions, were
    corrected.

    Returns:
        argparse.Namespace: The parsed options.
    """
    parser = argparse.ArgumentParser(description='Train LLM')

    # ---------------------------------------------------------------- model
    model_args = parser.add_argument_group(
        'model', 'Model loading, precision and memory options')
    model_args.add_argument('--llm', help='config file name or path.')
    model_args.add_argument(
        '-t',
        '--tokenizer',
        help=('repo id or local path of the tokenizer. '
              'Defaults to the same as `model`'))
    model_args.add_argument('--load-pretrain', action='store_true')

    model_args.add_argument(
        '--dtype',
        default='auto',
        choices=['fp16', 'bf16', 'auto'],
        help=("the dtype of the model forward. When set to 'auto', it will "
              'automatically determine whether bf16 is available, '
              'prioritizing the use of bf16.'))
    model_args.add_argument(
        '--selective-recompute',
        default=1.0,
        type=float,
        help=('the ratio of re-computation for transformer layers. '
              'The maximum is 1; the larger the value, the less memory '
              'required for training. The default is 1, meaning all layers '
              'need to be re-computed.'))
    model_args.add_argument('--cpu-offload', action='store_true', help=(''))
    model_args.add_argument(
        '--shard-strategy',
        default='full',
        choices=['full', 'hybrid', 'zero2', 'no', 'hybrid_zero2'],
        help=('The sharding strategy to be used for distributed training.'))

    # Overrides applied on top of the loaded config; `None` means "keep the
    # value from the config".
    custom_model_args = parser.add_argument_group('model',
                                                  'Custom model structure')
    custom_model_args.add_argument('--hidden-size', type=int, default=None)
    custom_model_args.add_argument(
        '--num-attention-heads', type=int, default=None)
    custom_model_args.add_argument(
        '--num-key-value-heads', type=int, default=None)
    custom_model_args.add_argument(
        '--intermediate-size', type=int, default=None)
    custom_model_args.add_argument(
        '--num-hidden-layers', type=int, default=None)
    custom_model_args.add_argument(
        '--vocab-size', type=int, default=None)
    custom_model_args.add_argument(
        '--n-shared-experts', type=int, default=None)
    custom_model_args.add_argument(
        '--num-experts-per-tok', type=int, default=None)
    custom_model_args.add_argument(
        '--num-routed-experts', type=int, default=None)
    custom_model_args.add_argument(
        '--head-dim', type=int, default=None)
    custom_model_args.add_argument(
        '--aux-loss-alpha', type=float, default=None)

    # ----------------------------------------------------------------- data
    data_args = parser.add_argument_group(
        'data', 'Training and validation data options')
    data_args.add_argument(
        '--datasets',
        nargs='*',
        help=('repo id or local path or dir of the datasets. For repo ids, '
              'the `dset-sources` needs to be appropriately set to '
              '`modelscope` or `huggingface`. For local dir, all json and '
              'jsonl files will be loaded by default. The type of loaded '
              'files can be controlled by setting `dset-file-type`'))
    data_args.add_argument(
        '--weights',
        type=str,
        default='tools/data_weight_h.json',
    )
    data_args.add_argument(
        '--max-length',
        type=int,
        default=2048,
        help=('the maximum length of each piece of data, any excess will be '
              'truncated.'))
    data_args.add_argument(
        '--num-proc',
        type=int,
        default=8,
        help='how many subprocesses to use for data mapping.')

    data_args.add_argument('--val-datasets', default=None)
    data_args.add_argument(
        '--val-dset-cache-dir',
        help=('the cache dir of the loaded datasets. When the `datasets` is '
              'set, the loaded datasets will be cached to this dir. If the '
              '`datasets` are not set, the cached dataset in this dir will be '
              'loaded.'))
    data_args.add_argument(
        '--val-dset-from-cache',
        action='store_true',
        help=('Load data directly from `dset-cache-dir`. This can save time '
              'on online tokenization, but if the tokenizer changed, '
              'recaching is needed.'))
    data_args.add_argument(
        '--val-mirco-batch-size',
        type=int,
        default=4,
    )

    # ----------------------------------------------------------------- dist
    dist_args = parser.add_argument_group(
        'dist', 'Parallelism group sizes')
    dist_args.add_argument('--sp-size', type=int, default=1, help='')
    dist_args.add_argument('--ep-size', type=int, default=1, help='')

    # ------------------------------------------------------------ optimizer
    optim_args = parser.add_argument_group(
        'optimizer', 'Batching, learning rate and schedule options')
    optim_args.add_argument(
        '--mirco-batch-size',
        type=int,
        default=1,
        help='batch size for each forward + backward pass')
    optim_args.add_argument(
        '--global-batch-size',
        type=int,
        default=16,
        help='batch size for each optimizer step')
    optim_args.add_argument(
        '--lr',
        '--learning-rate',
        default=4e-5,
        type=float,
        help='the base learning rate.')
    optim_args.add_argument('--lr-min', default=1.5e-6, type=float)
    optim_args.add_argument(
        '--wd', default=0.01, type=float, help='weight decay.')
    optim_args.add_argument(
        '--max-grad-norm', default=1, type=float, help='gradient clipping')
    optim_args.add_argument(
        '--total-steps',
        default=30000,
        type=int,
        help='the number of training steps to run.',
    )
    optim_args.add_argument(
        '--warmup-ratio',
        default=0.03,
        type=float,
        help=('the proportion of training steps for learning rate warm-up in '
              'relation to the total training steps.'))

    # -------------------------------------------------------------- runtime
    parser.add_argument('-c', '--config', default=None)
    parser.add_argument(
        '--work-dir',
        default='work_dirs',
        help='the dir to save logs and models')
    parser.add_argument(
        '--checkpoint-interval',
        default=-1,
        type=float,
        help=('how many steps to save a checkpoint; it can be a floating '
              'point number less than 1, or an integer greater than or equal '
              "to 1. When it's a floating point, it will be multiplied by the "
              'total number of training steps.'))
    parser.add_argument('--save-fused-checkpoint', action='store_true')
    parser.add_argument(
        '--checkpoint-drop-optimizer',
        action='store_true',
        help=('only model parameters are saved when saving a checkpoint. '
              'This can significantly reduce the size of checkpoint files, '
              'but the saved checkpoints cannot be resumed.'))
    parser.add_argument('--log-interval', default=1, type=int)
    parser.add_argument('--val-interval', default=5000, type=int)
    parser.add_argument(
        '--resume',
        type=str,
        default=None,
        help='specify checkpoint path to be resumed from.')
    parser.add_argument(
        '--seed', type=int, default=0, help='Random seed for the training')
    parser.add_argument(
        '--debug', action='store_true', help='Set logger level to `DEBUG`')
    parser.add_argument('--checkpoint-is-fused', action='store_true')
    return parser.parse_args()


def is_interval(step, total_steps, interval):
    """Tell whether a periodic action is due after the given step.

    ``step + 1`` is treated as the number of completed steps. The action is
    due when that count is a multiple of ``interval`` or equals
    ``total_steps``.
    """
    completed = step + 1
    if completed % interval == 0:
        return True
    return completed == total_steps


def map_meta_modules(model, meta_model):
    """Pair every module of ``meta_model`` with its namesake in ``model``.

    Both arguments must provide ``named_modules()``. Every name yielded by
    ``meta_model`` has to exist in ``model``, otherwise ``KeyError`` is
    raised.

    Returns:
        dict: ``{meta_module: module_of_model_with_the_same_name}``.
    """
    real_by_name = dict(model.named_modules())
    return {
        meta_mod: real_by_name[mod_name]
        for mod_name, meta_mod in meta_model.named_modules()
    }


def build_streamings(ds_path_list, ds_weights=None):
    """Collect ``.jsonl`` files as ``Streaming`` datasets with their weights.

    Args:
        ds_path_list: Paths to scan. A directory is walked recursively
            (following symlinks) for ``*.jsonl`` files; a regular file is
            used only if its name ends with ``.jsonl``.
        ds_weights: Optional path of a JSON file holding a mapping from a
            path substring to a weight. When given, a file gets the weight
            of the first key contained in its path, and files matching no
            key are dropped. When omitted (or empty), every file gets
            weight ``1``.

    Returns:
        tuple: ``(streamings, weights)`` -- two lists of equal length, where
        ``weights[i]`` belongs to ``streamings[i]``.

    Raises:
        NotImplementedError: If an entry of ``ds_path_list`` is neither a
            directory nor a regular file.
    """
    weight_map = None
    if ds_weights:
        with open(ds_weights) as f:
            weight_map = json.load(f)

    streamings = []
    weights = []
    for ds_path in ds_path_list:
        if os.path.isdir(ds_path):
            paths = [
                os.path.join(dirpath, filename)
                for dirpath, _, filenames in os.walk(
                    ds_path, followlinks=True)
                for filename in filenames
                if filename.endswith('.jsonl')
            ]
        elif os.path.isfile(ds_path):
            paths = [ds_path] if ds_path.endswith('.jsonl') else []
        else:
            raise NotImplementedError(
                f'{ds_path} is neither a directory nor a regular file.')

        for path in paths:
            weight = 1
            if weight_map is not None:
                # First key that is a substring of the path wins.
                weight = next(
                    (val for key, val in weight_map.items() if key in path),
                    None)
                if weight is None:
                    logger.info(
                        f'file {path} is dropped as it is not in {weight_map}.'
                    )
                    continue

            # Files of at most 1000 bytes are skipped.
            if os.path.getsize(path) <= 1000:
                continue

            # Append the stream and its weight together so the two lists can
            # never get out of step (previously the weight was recorded even
            # when the file was then skipped for being too small).
            streamings.append(Streaming(path, max_epoch=1))
            weights.append(weight)

    logger.info(f'Found {len(streamings)} streaming datasets.')

    return streamings, weights


def cal(llm, cfg):
    """Print parameter-count statistics of ``llm`` in billions.

    Parameters are classified by substrings of their names:

    * ``'expert'`` in the name -> counted as MoE, otherwise as "other";
    * ``'.experts.'`` in the name -> contributes only the fraction
      ``cfg.num_experts_per_tok / cfg.num_routed_experts`` to the activated
      count, otherwise it contributes fully;
    * ``'attention'`` or ``'self_attn'`` in the name -> counted as attention.

    ``cfg`` is only read when a ``'.experts.'`` parameter is encountered.
    """
    total = 0
    activated = 0
    moe = 0
    non_moe = 0
    attn = 0

    for name, param in llm.named_parameters():
        count = param.numel()
        total += count

        if 'expert' in name:
            moe += count
        else:
            non_moe += count

        if '.experts.' in name:
            activated += (
                count * cfg.num_experts_per_tok / cfg.num_routed_experts)
        else:
            activated += count

        if 'attention' in name or 'self_attn' in name:
            attn += count

    billion = 1e9
    stats = (
        ('Total act param', activated / billion),
        ('Total param', total / billion),
        ('MoE param', moe / billion),
        ('Other param', non_moe / billion),
        ('Attn param', attn / billion),
        ('MoE act param', (activated - non_moe) / billion),
    )
    print(', '.join(f'{label}: {value}' for label, value in stats))


def build_llm_model(args, config, dtype=torch.float32):
    """Instantiate the causal LM, from pretrained weights or from config.

    Args:
        args: Parsed command-line options. ``args.load_pretrain`` selects
            the loading path; ``args.llm`` is the pretrained source; the
            structural options (``hidden_size`` ... ``aux_loss_alpha``)
            override the config when they are not ``None``.
        config: Model config. It is deep-copied before overrides are
            applied, so the caller's object is left untouched.
        dtype: dtype the returned model is cast to.

    Returns:
        The model, cast to ``dtype`` and with ``config.use_cache`` disabled.
    """
    if args.load_pretrain:
        with LoadWoInit():
            llm = AutoModelForCausalLM.from_pretrained(
                args.llm,
                trust_remote_code=True,
                torch_dtype=dtype,
                attn_implementation=config.attn_implementation)
    else:
        cfg = copy.deepcopy(config)

        # Config fields that can be overridden from the command line; the
        # argparse destination and the config attribute share the same name.
        # Note: Qwen uses `shared_expert_intermediate_size` rather than
        # `n_shared_experts`.
        overridable_fields = (
            'hidden_size',
            'intermediate_size',
            'num_attention_heads',
            'num_key_value_heads',
            'num_hidden_layers',
            'vocab_size',
            'n_shared_experts',
            'num_experts_per_tok',
            'num_routed_experts',
            'head_dim',
            'aux_loss_alpha',
        )
        for field in overridable_fields:
            override = getattr(args, field)
            if override is not None:
                setattr(cfg, field, override)

        # NOTE(review): this path hard-codes flash attention while the
        # pretrained path uses `config.attn_implementation` -- confirm this
        # difference is intended.
        llm = AutoModelForCausalLM.from_config(
            config=cfg,
            trust_remote_code=True,
            torch_dtype=config.torch_dtype,
            attn_implementation='flash_attention_2')

    # Only the global rank 0 prints the parameter statistics.
    if dist.get_rank() == 0:
        cal(llm, llm.config)

    # Ensure all numerical values in the optimizer are fp32.
    # FSDP will use low precision during forward.
    llm.to(dtype)
    llm.config.use_cache = False

    return llm


# from xpuyu.modelings.internlm_moe.modeling_internlm3_moe import load_balancing_loss_func

# def internlm3_moe_forward_fused_linear_ce(
#     self,
#     input_ids: torch.LongTensor = None,
#     attention_mask: Optional[torch.Tensor] = None,
#     position_ids: Optional[torch.LongTensor] = None,
#     past_key_values: Optional[List[torch.FloatTensor]] = None,
#     inputs_embeds: Optional[torch.FloatTensor] = None,
#     labels: Optional[torch.LongTensor] = None,
#     use_cache: Optional[bool] = None,
#     output_attentions: Optional[bool] = None,
#     output_hidden_states: Optional[bool] = None,
#     output_router_logits: Optional[bool] = None,
#     return_dict: Optional[bool] = None,
#     cache_position: Optional[torch.LongTensor] = None,
#     num_logits_to_keep: int = 0,
#     **loss_kwargs,
# ):
#     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
#     output_router_logits = (
#         output_router_logits if output_router_logits is not None else self.config.output_router_logits
#     )
#     output_hidden_states = (
#         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
#     )
#     return_dict = return_dict if return_dict is not None else self.config.use_return_dict

#     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
#     outputs = self.model(
#         input_ids=input_ids,
#         attention_mask=attention_mask,
#         position_ids=position_ids,
#         past_key_values=past_key_values,
#         inputs_embeds=inputs_embeds,
#         use_cache=use_cache,
#         output_attentions=output_attentions,
#         output_hidden_states=output_hidden_states,
#         output_router_logits=output_router_logits,
#         return_dict=return_dict,
#         cache_position=cache_position,
#     )

#     hidden_states = outputs[0]

#     loss = None
#     logits = None

#     if self.training and (labels is not None):
#         shift_hidden_states = hidden_states[..., :-1, :].contiguous()
#         shift_labels = labels[..., 1:].contiguous()

#         # flatten tokens
#         shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
#         shift_labels = shift_labels.view(-1)

#         lce = LigerFusedLinearCrossEntropyLoss()
#         loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)

#     else:
#         logits = self.lm_head(hidden_states)
#         if labels is not None:
#             # Upcast to float if we need to compute the loss to avoid potential precision issues
#             logits = logits.float()
#             # Shift so that tokens < n predict n
#             shift_logits = logits[..., :-1, :].contiguous()
#             shift_labels = labels[..., 1:].contiguous()
#             # Flatten the tokens
#             loss_fct = CrossEntropyLoss()
#             shift_logits = shift_logits.view(-1, self.config.vocab_size)
#             shift_labels = shift_labels.view(-1)
#             # Enable model parallelism
#             shift_labels = shift_labels.to(shift_logits.device)
#             loss = loss_fct(shift_logits, shift_labels)

#     aux_loss = None
#     if output_router_logits:
#         aux_loss = load_balancing_loss_func(
#             outputs.router_logits if return_dict else outputs[-1],
#             self.num_routed_experts,
#             self.num_experts_per_tok,
#             attention_mask,
#         )
#         if labels is not None:
#             loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure to reside in the same device

#     if not return_dict:
#         output = (logits,) + outputs[1:]
#         if output_router_logits:
#             output = (aux_loss,) + output
#         return (loss,) + output if loss is not None else output

#     return MoeCausalLMOutputWithPast(
#         loss=loss,
#         aux_loss=aux_loss,
#         logits=logits,
#         past_key_values=outputs.past_key_values,
#         hidden_states=outputs.hidden_states,
#         attentions=outputs.attentions,
#         router_logits=outputs.router_logits,
#     )


@torch.no_grad()
def reduce_ep_grad(llm, ep_size):
    """Divide the expert-weight gradients of ``llm`` by ``ep_size`` in place.

    Every submodule whose class is named ``GroupedLinear`` has the gradients
    of its ``w1w3`` and ``w2`` attributes scaled by ``1 / ep_size``. Weights
    whose gradient is ``None`` are left alone.
    """
    for module in llm.modules():
        # Matched by class name rather than isinstance; the class itself is
        # not imported in this file.
        if type(module).__name__ != 'GroupedLinear':
            continue
        for weight_name in ('w1w3', 'w2'):
            grad = getattr(module, weight_name).grad
            if grad is not None:
                grad.div_(ep_size)


from torch.nn.utils.clip_grad import _no_grad
from torch.utils._foreach_utils import (
    _device_has_foreach_support,
    _group_tensors_by_device_and_dtype,
    _has_foreach_support,
)


@_no_grad
def clip_grad_norm_(
    moe_params,
    non_moe_params,
    experts_fsdp_mesh,
    max_norm: float,
    norm_type: float = 2.0,
    error_if_nonfinite: bool = False,
    foreach: Optional[bool] = None,
) -> torch.Tensor:
    if isinstance(moe_params, torch.Tensor):
        moe_params = [moe_params]
    if isinstance(non_moe_params, torch.Tensor):
        non_moe_params = [non_moe_params]
    moe_grads = [p.grad for p in moe_params if p.grad is not None]
    non_moe_grads = [p.grad for p in non_moe_params if p.grad is not None]
    max_norm = float(max_norm)
    norm_type = float(norm_type)
    if len(moe_grads) + len(non_moe_grads) == 0:
        return torch.tensor(0.0)
    first_device = moe_grads[0].device
    grouped_moe_grads: Dict[
        Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], List[int]]
    ] = _group_tensors_by_device_and_dtype(
        [moe_grads]
    )
    grouped_non_moe_grads: Dict[
        Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], List[int]]
    ] = _group_tensors_by_device_and_dtype(
        [non_moe_grads]
    )
    moe_norms: List[Tensor] = []
    non_moe_norms: List[Tensor] = []

    for (device, _), ([device_grads], _) in grouped_moe_grads.items():  # type: ignore[assignment]
        if (foreach is None and _has_foreach_support(device_grads, device)) or (
            foreach and _device_has_foreach_support(device)
        ):
            moe_norms.extend(torch._foreach_norm(device_grads, norm_type))
        elif foreach:
            raise RuntimeError(
                f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
            )
        else:
            moe_norms.extend([torch.linalg.vector_norm(g, norm_type) for g in device_grads])
    
    for (device, _), ([device_grads], _) in grouped_non_moe_grads.items():  # type: ignore[assignment]
        if (foreach is None and _has_foreach_support(device_grads, device)) or (
            foreach and _device_has_foreach_support(device)
        ):
            non_moe_norms.extend(torch._foreach_norm(device_grads, norm_type))
        elif foreach:
            raise RuntimeError(
                f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
            )
        else:
            non_moe_norms.extend([torch.linalg.vector_norm(g, norm_type) for g in device_grads])
    
    local_sharded_moe_norm = torch.linalg.vector_norm(
        torch.stack([norm.to_local().to(first_device) for norm in moe_norms]), norm_type, dtype=torch.float32
    )
    local_sharded_non_moe_norm = torch.linalg.vector_norm(
        torch.stack([norm.to_local().to(first_device) for norm in non_moe_norms]), norm_type, dtype=torch.float32
    )

    if norm_type == 2:
        total_sharded_moe_norm = local_sharded_moe_norm**norm_type
        total_sharded_non_moe_norm = local_sharded_non_moe_norm**norm_type
        dist.all_reduce(total_sharded_moe_norm)
        dist.all_reduce(total_sharded_non_moe_norm, group=experts_fsdp_mesh.get_group(mesh_dim=0))
        total_norm = (total_sharded_moe_norm + total_sharded_non_moe_norm) ** 0.5
    else:
        raise NotImplementedError

    if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
        raise RuntimeError(
            f"The total norm of order {norm_type} for gradients from "
            "`parameters` is non-finite, so it cannot be clipped. To disable "
            "this error and scale the gradients by the non-finite norm anyway, "
            "set `error_if_nonfinite=False`"
        )
    clip_coef = max_norm / (total_norm + 1e-6)
    # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so
    # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization
    # when the gradients do not reside in CPU memory.
    clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
    for (device, _), ([device_grads], _) in grouped_moe_grads.items():  # type: ignore[assignment]
        if (foreach is None and _has_foreach_support(device_grads, device)) or (
            foreach and _device_has_foreach_support(device)
        ):
            torch._foreach_mul_(device_grads, clip_coef_clamped.to(device))
        elif foreach:
            raise RuntimeError(
                f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
            )
        else:
            clip_coef_clamped_device = clip_coef_clamped.to(device)
            for g in device_grads:
                g.mul_(clip_coef_clamped_device)
    
    for (device, _), ([device_grads], _) in grouped_non_moe_grads.items():  # type: ignore[assignment]
        if (foreach is None and _has_foreach_support(device_grads, device)) or (
            foreach and _device_has_foreach_support(device)
        ):
            torch._foreach_mul_(device_grads, clip_coef_clamped.to(device))
        elif foreach:
            raise RuntimeError(
                f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
            )
        else:
            clip_coef_clamped_device = clip_coef_clamped.to(device)
            for g in device_grads:
                g.mul_(clip_coef_clamped_device)

    return total_norm


@logger.catch
def main(args):
    """Run the full training job: environment, data, model, optimizer, loop.

    The function is organised in five numbered stages (see the banner
    comments below). It performs gradient accumulation over
    ``iters_per_step`` micro-batches per optimizer step, clips gradients
    separately for expert (``'.experts.'`` in the parameter name) and
    non-expert parameters, and logs per-step metrics.

    Args:
        args: Parsed command-line namespace. Fields read here: ``sp_size``,
            ``seed``, ``global_batch_size``, ``mirco_batch_size``,
            ``work_dir``, ``debug``, ``tokenizer``, ``llm``, ``datasets``,
            ``weights``, ``max_length``, ``dtype``, ``selective_recompute``,
            ``lr``, ``wd``, ``total_steps``, ``checkpoint_interval``,
            ``warmup_ratio``, ``lr_min``, ``max_grad_norm`` and
            ``log_interval``. ``args.work_dir`` is overwritten with a
            timestamped sub-directory and ``args.dtype`` is resolved when it
            is ``'auto'``.

    Raises:
        NotImplementedError: If the expert-parallel mesh has more than one
            rank.
        ValueError: If ``global_batch_size`` is not a multiple of
            ``dp_size * mirco_batch_size``.
        RuntimeError: If ``args.dtype`` is unsupported on this device or is
            not one of ``fp16``/``bf16``/``auto``.
    """
    ###########################################################################
    #                           1. Environment                                #
    ###########################################################################

    setup_parallel(sp_size=args.sp_size, tp_size=1, ep_size=1)
    set_random_seed(args.seed)

    dp_mesh = get_dp_mesh()
    sp_mesh = get_sp_mesh()
    ep_mesh = get_ep_mesh()
    experts_fsdp_mesh = get_experts_fsdp_mesh()
    world_mesh = get_world_mesh()

    if ep_mesh.size() > 1:
        raise NotImplementedError

    dp_size = dp_mesh.size()
    sp_size = sp_mesh.size()
    world_size = world_mesh.size()

    # The batch is split across data-parallel ranks only, so the checks (and
    # their messages) are expressed in terms of `dp_size`.
    if args.global_batch_size < dp_size or args.global_batch_size % dp_size:
        raise ValueError(f'The `global_batch_size`({args.global_batch_size}) '
                         'should be divisible by the '
                         f'dp_size({dp_size}).')

    if (args.global_batch_size // dp_size) % args.mirco_batch_size:
        raise ValueError(f'The `global_batch_size`({args.global_batch_size}) '
                         f'should be divisible by the dp_size({dp_size})'
                         f' * `mirco_batch_size`({args.mirco_batch_size})')

    rank = dist.get_rank()
    timestamp = datetime.now().strftime('%Y%m%d%H%M%S')

    # Adopt rank 0's timestamp everywhere so that all ranks write into the
    # same work directory.
    objects = [timestamp]
    dist.broadcast_object_list(objects, src=0)
    timestamp = objects[0]

    args.work_dir = os.path.join(args.work_dir, timestamp)
    mkdir_or_exist(args.work_dir)

    log_file = os.path.join(args.work_dir, f'rank{rank}.log')
    vis_data_file = os.path.join(args.work_dir, 'vis_data.jsonl')

    # Change the log format printed in the terminal
    lvl = 'DEBUG' if args.debug else 'INFO'
    logger.add(sys.stderr, level=lvl, format=log_format(rank, args.debug))
    # Change the format saved in the log file
    logger.add(log_file, format=log_format(rank), backtrace=True, catch=True)

    logger.info(args)
    if rank == 0:
        env = collect_env()
        import transformers

        import xtuner
        env['Transformers'] = transformers.__version__
        env['XTuner'] = f'{xtuner.__version__}+{get_git_hash(digits=6)}'
        runtime_env = OrderedDict()
        runtime_env.update(env)
        runtime_env['Seed'] = args.seed
        runtime_env['World Size'] = world_size
        runtime_env['DP Size'] = dp_size
        runtime_env['SP Size'] = sp_size
        # runtime_env['Distributed launcher'] = dist_launcher

        runtime_env_info = '\n    ' + '\n    '.join(
            f'{k}: {v}' for k, v in runtime_env.items())
        dash_line = '-' * 60
        logger.info('\n' + dash_line + '\nRuntime environment:' +
                    runtime_env_info + '\n' + dash_line + '\n')
    # -------------------    Environment  End  ------------------------------ #

    ###########################################################################
    #                     2. Dataset & Dataloader                             #
    ###########################################################################

    start_load_data_t = time.time()

    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer if args.tokenizer else args.llm,
        trust_remote_code=True,
        padding_side='right',
        use_fast=False
    )

    pack_batch = is_flash_attn_2_available()
    collator = SftCollator(pack_batch=pack_batch)

    tokenize_fn = PretrainTokenizeFunction(tokenizer)
    train_streamings, train_weight = build_streamings(args.datasets,
                                                      args.weights)
    train_dataset = MultiStreamingDataset(
        train_streamings,
        train_weight,
        args.max_length,
        tokenize_fn,
        seed=args.seed,
        # Use the rank inside the data-parallel mesh, not the global rank:
        # with `sp_size > 1` the global rank exceeds `dp_size`, and ranks of
        # one sequence-parallel group would otherwise read different shards.
        dp_rank=dp_mesh.get_local_rank(),
        dp_world_size=dp_size,
        pack='hard')

    train_dataloader = DataLoader(
        dataset=train_dataset,
        batch_size=args.mirco_batch_size,
        num_workers=1,
        collate_fn=collator,
        # Ensure to round up or drop last based on the `global_batch_size`,
        # if you want to replace a custom sampler.
    )

    load_data_cost_time = time.time() - start_load_data_t
    logger.info(f'[Dataset & Dataloader] Cost {load_data_cost_time:.2f}s')

    # -------------------    Dataset & Dataloader  End  --------------------- #

    ###########################################################################
    #                          3. FSDP                                        #
    ###########################################################################

    if args.dtype == 'auto':
        args.dtype = 'bf16' if torch.cuda.is_bf16_supported() else 'fp16'

    # NOTE(review): `autocast` and `scaler` are created here but never used
    # later in this function; in particular the fp16 path runs backward
    # without loss scaling. Confirm whether that is intended.
    if args.dtype == 'fp16':
        dtype = torch.float16
        autocast = torch.cuda.amp.autocast(enabled=True, dtype=dtype)
        scaler = ShardedGradScaler()
    elif args.dtype == 'bf16':
        if torch.cuda.is_bf16_supported():
            dtype = torch.bfloat16
            autocast = torch.cuda.amp.autocast(enabled=True, dtype=dtype)
            scaler = None
        else:
            raise RuntimeError('The device does not support `bf16`, '
                               'please set `dtype` to `fp16`.')
    else:
        raise RuntimeError('`dtype` only supports `fp16`, `bf16` or `auto`, '
                           f'but found {args.dtype}.')

    llm_cfg = AutoConfig.from_pretrained(args.llm, trust_remote_code=True)
    if is_flash_attn_2_available():
        llm_cfg.attn_implementation = 'flash_attention_2'
    elif is_torch_sdpa_available():
        llm_cfg.attn_implementation = 'sdpa'

    llm_cfg.use_cache = False
    llm_cfg.torch_dtype = dtype

    # Only load parameters on rank 0 to avoid each rank repeatedly loading the
    # same model into the CPU, wasting memory
    xtuner_load_timeout = timedelta(minutes=60)
    group_gloo = dist.new_group(backend='gloo', timeout=xtuner_load_timeout)

    if rank == 0:
        with torch.device('cpu'):
            rank0_llm = build_llm_model(args, llm_cfg, dtype)
    else:
        rank0_llm = None

    # Other ranks wait here (up to the timeout above) while rank 0 loads.
    dist.monitored_barrier(group=group_gloo, timeout=xtuner_load_timeout)
    logger.info('after barrier')

    # Build the model skeleton on the meta device (no real storage) and
    # re-wrap every trainable parameter as an fp32 `nn.Parameter`.
    with torch.device('meta'):
        llm = build_llm_model(
            args,
            llm_cfg,
            dtype=torch.float32)
        dispatch_hf_code(llm)
        for module in llm.modules():
            for p_name, param in module.named_parameters(recurse=False):
                if param.requires_grad:
                    param_fp32 = torch.nn.Parameter(
                        param.to(dtype=torch.float32))
                    setattr(module, p_name, param_fp32)

    # logger.info('dispatch internlm3_moe_forward_fused_linear_ce')
    # llm.forward = types.MethodType(internlm3_moe_forward_fused_linear_ce, llm)

    mp_policy = MixedPrecisionPolicy(param_dtype=dtype, reduce_dtype=dtype)

    # Presumably shards `llm` over the meshes and fills its weights from
    # `rank0_llm`; the helper is defined elsewhere -- TODO confirm.
    with profile_time_and_memory('[Parallelize LLM]'):
        megatron_internlm3_moe_casual(
            llm,
            rank0_llm,
            experts_fsdp_mesh=experts_fsdp_mesh,
            ep_mesh=ep_mesh,
            mp_policy=mp_policy,
            recompute_ratio=args.selective_recompute,
            reshard_after_forward=True)

        llm.train()

    if rank == 0:
        logger.info(llm)

    # --------------------------    FSDP  End  ------------------------------ #

    ###########################################################################
    #                      4. Optimizer & Scheduler                           #
    ###########################################################################

    requried_grad_params = [
        param for param in llm.parameters() if param.requires_grad
    ]
    # Expert and non-expert parameters are kept in separate lists because
    # gradient clipping below takes them as two distinct groups.
    requried_grad_moe_params = []
    requried_grad_non_moe_params = []
    for name, param in llm.named_parameters():
        if not param.requires_grad:
            continue
        if '.experts.' in name:
            requried_grad_moe_params.append(param)
        else:
            requried_grad_non_moe_params.append(param)

    optimizer = AdamW(
        requried_grad_params,
        lr=args.lr,
        weight_decay=args.wd,
        betas=(0.9, 0.95),
        eps=1e-8)

    global_batch_size = args.global_batch_size
    mirco_batch_size = args.mirco_batch_size

    # `iter` means once forward+backward
    # `step` means once optimizer step
    # `iters_per_step` means gradient accumulative counts
    total_steps = args.total_steps
    iters_per_step = global_batch_size // mirco_batch_size // dp_size

    # NOTE(review): `checkpoint_interval` is resolved here but not referenced
    # again in this function, i.e. this loop never saves a checkpoint.
    if args.checkpoint_interval == -1:
        checkpoint_interval = total_steps
    elif args.checkpoint_interval < 1:
        checkpoint_interval = int(total_steps * args.checkpoint_interval)
    else:
        checkpoint_interval = int(args.checkpoint_interval)

    # A `warmup_ratio` below 1 is a fraction of `total_steps`; otherwise it
    # is taken as an absolute number of steps.
    if args.warmup_ratio < 1:
        warmup_steps = int(args.warmup_ratio * total_steps)
    else:
        warmup_steps = int(args.warmup_ratio)

    def warmup_fn(x):
        return x / warmup_steps if x < warmup_steps else 1

    warmup_scheduler = LambdaLR(optimizer, warmup_fn)

    cosine_scheduler = CosineAnnealingLR(
        optimizer, T_max=total_steps - warmup_steps, eta_min=args.lr_min)

    start_step = 0

    # ----------------    Optimizer & Scheduler End   ----------------------- #

    ###########################################################################
    #                          5. Training                                    #
    ###########################################################################

    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    max_memory = torch.cuda.max_memory_allocated()
    logger.info('[Train] Begin Train Loop. The current GPU memory is '
                f'{(max_memory / 1024**3):.1f}GB')

    data_iterator = iter(train_dataloader)
    start_train_t2 = time.time()

    for step in range(start_step, total_steps):

        # Linear warmup first, cosine annealing afterwards.
        if step <= warmup_steps:
            warmup_scheduler.step()
            cur_lr = warmup_scheduler.get_last_lr()[0]
        else:
            cosine_scheduler.step()
            cur_lr = cosine_scheduler.get_last_lr()[0]

        torch.cuda.reset_peak_memory_stats()

        step_loss = 0
        step_data_time = 0
        step_start_t = time.time()
        step_consumed_tokens = 0

        # Gradient accumulation: `iters_per_step` forward/backward passes
        # per optimizer step.
        for _ in range(iters_per_step):

            _data_start_t = time.time()
            data = next(data_iterator)
            step_data_time += time.time() - _data_start_t

            input_ids = data['input_ids'].to(DEVICE)
            labels = data['labels'].to(DEVICE)
            attention_mask = data['attention_mask'].to(DEVICE)
            num_tokens = data['num_tokens'].to(DEVICE)

            # todo: support sp split

            packed_ctx = packed_sequence(num_tokens, sp_mesh=sp_mesh)
            with packed_ctx:
                ctx = MessageHub.get_instance('packed_sequence')
                position_ids = ctx.get_info('position_ids')
                outputs = llm(
                    input_ids=input_ids,
                    labels=labels,
                    position_ids=position_ids,
                    attention_mask=attention_mask,
                )
                loss = outputs.loss
                # Only the last micro-batch's aux loss survives the loop and
                # is what gets logged below.
                aux_loss = outputs.aux_loss[0].detach()
                # if isinstance(outputs, MoeCausalLMOutputWithPast) and hasattr(outputs, 'aux_loss'):
                #     aux_loss = outputs.aux_loss.detach()
                # else:
                #     aux_loss = None

                # Scale so that the accumulated gradient is the mean over
                # the micro-batches of this step.
                avg_iter_loss = loss / iters_per_step
                avg_iter_loss.backward()

            step_consumed_tokens += num_tokens.sum() / sp_mesh.size()
            step_loss += avg_iter_loss.item()

        reduce_ep_grad(llm, ep_mesh.size())
        grad_norm = clip_grad_norm_(requried_grad_moe_params, requried_grad_non_moe_params, experts_fsdp_mesh, args.max_grad_norm, foreach=True)
        grad_norm = grad_norm.to_local() if isinstance(grad_norm, DTensor) else grad_norm
        optimizer.step()
        optimizer.zero_grad()

        step_time = time.time() - step_start_t
        eta = step_time * (total_steps - step)
        eta = timedelta(seconds=int(eta))
        tgs = int(step_consumed_tokens / step_time)
        max_memory = torch.cuda.max_memory_allocated()

        all_time = time.time() - start_train_t2
        # Extrapolates from the current step's token count to all steps run
        # so far.
        tgs_end2end = int(step_consumed_tokens * (step - start_step + 1) / all_time)
        ####################################################################################################
        # all reduce loss
        # Keep the local (per-rank) value for the text log; the averaged
        # value goes to the visualisation file.
        step_loss_pre_rank = step_loss
        step_loss = torch.tensor(step_loss, device='cuda')
        dist.all_reduce(step_loss)
        step_loss = step_loss.item() / world_size
        if aux_loss is not None:
            # Reduce a copy so the per-rank `aux_loss` stays intact for the
            # log line below.
            aux_loss_reduced = aux_loss.clone()
            dist.all_reduce(aux_loss_reduced)
            aux_loss_reduced = aux_loss_reduced.item() / world_size
        else:
            aux_loss_reduced = None
        ####################################################################################################

        if is_interval(step, total_steps, args.log_interval):
            logger.info(f'[Train] (Step {step + 1}/{total_steps})  '
                        f'lr: {cur_lr:.6f}  loss: {step_loss_pre_rank:.3f}  aux_loss: {aux_loss if aux_loss is not None else None}   '
                        f'grad_norm: {grad_norm:.2f}  '
                        f'max_memory: {(max_memory / 1024 ** 3):.1f}GB  '
                        f'text_tokens: {step_consumed_tokens}  '
                        f'tgs: {tgs}  tgs_end2end: {tgs_end2end}  data_time: {step_data_time:.2f}s  '
                        f'time: {step_time:.2f}s  '
                        f'eta: {eta}')
            if rank == 0:
                # One JSON object per logged step, appended by rank 0 only.
                vis_data = dict(
                    Step=step + 1,
                    lr=cur_lr,
                    loss=step_loss,
                    aux_loss=aux_loss_reduced,
                    max_memory=max_memory / 1024**3,
                    text_tokens=step_consumed_tokens.item(),
                    tgs=tgs,
                    data_time=step_data_time)
                with open(vis_data_file, 'a', encoding='utf-8') as f:
                    f.write(json.dumps(vis_data) + '\n')


if __name__ == '__main__':
    # Script entry point: parse the command-line options with `parse_args`
    # and hand the resulting namespace to `main`.

    args = parse_args()
    main(args)
